Disable KD mode from saving problematic state #320
Walkthrough
Updated docs, examples, configs, and tests to broaden the accepted `teacher_model` inputs (now including instantiated `nn.Module`s) and to change how "kd_loss" mode state is saved and restored.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant T as KDTrainer
    participant KD as KnowledgeDistillationModeDescriptor
    participant M as Model (student/exported)
    participant S as ModeloptStateManager
    Note over T,KD: Pre-save reset then full state save
    T->>KD: update_for_save()
    activate KD
    KD->>M: _reset_kd_state_config (teacher_model -> marker, criterion -> Loss(), loss_balancer -> None)
    deactivate KD
    T->>S: collect modelopt_state (unfiltered)
    T->>T: save model + modelopt_state (kd entries preserved)
```

```mermaid
sequenceDiagram
    autonumber
    participant U as User
    participant S as ModeloptStateManager
    participant M as Model (on disk)
    participant KD as KnowledgeDistillationModeDescriptor
    Note over U,S: Restore returns raw model (no DistillationModel)
    U->>S: restore(checkpoint)
    S-->>U: return model instance (no DistillationModel)
    U->>KD: (optional) update_for_new_mode()
    KD->>M: _reset_kd_state_config (ensure KD config consistent)
```
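To make the two flows concrete, here is a rough usage sketch of the behavior the diagrams describe. It assumes the `mto.save`/`mto.restore` helpers and the `mtd.convert` call shown in the docs diff later in this review; treat the exact entry points, signatures, and the `fresh_student_model` name as assumptions to verify.

```python
import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Convert the student into a DistillationModel for KD training
# (`student_model` and `distillation_config` as in the README example).
distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])

# Saving triggers update_for_save(): teacher_model, criterion, and loss_balancer
# are reset to pickle-friendly placeholders before the modelopt state is serialized.
mto.save(distillation_model, "kd_checkpoint.pth")

# Restoring returns the raw student; the DistillationModel wrapper is NOT rebuilt.
# `fresh_student_model` is a newly constructed student instance (hypothetical name).
student = mto.restore(fresh_student_model, "kd_checkpoint.pth")

# To resume KD training, re-convert the restored student explicitly.
distillation_model = mtd.convert(student, mode=[("kd_loss", distillation_config)])
```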
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
docs/source/guides/4_distillation.rst (2)
42-55: Example uses undefined variable teacher_model. The snippet sets teacher_model in the config but never defines it.
Apply this diff to define a simple teacher for the example:
```diff
 from torchvision.models import resnet50
 @@
 # User-defined model (student)
 model = resnet50()

 # Configure and convert for distillation
 distillation_config = {
-    # `teacher_model` is a model, model class, callable, or a tuple.
+    # A simple example teacher; in practice use a stronger model.
+    "teacher_model": resnet50(),
+    # `teacher_model` can be a model, model class, callable, or a tuple.
-    # If a tuple, it must be of the form (model_cls_or_callable,) or
+    # If a tuple, it must be of the form (model_cls_or_callable,) or
     # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs).
-    "teacher_model": teacher_model,
     "criterion": mtd.LogitsDistillationLoss(),
     "loss_balancer": mtd.StaticLossBalancer(),
 }
```
156-162: Typo in API alias: atd → mtd. The code won't run as written.
Apply this diff:
```diff
-distillation_model = atd.convert(student_model, mode=[("kd_loss", distillation_config)])
+distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)])
```
🧹 Nitpick comments (12)
examples/llm_distill/README.md (4)
42-45: Example may OOM on 70B teacher; suggest device_map to shard or offload. Loading a 70B model on a single device will likely OOM. Recommend showing a safer pattern in the snippet.
Apply this diff to the example:
```diff
-# Define student & teacher
-student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
+# Define student & teacher
+student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+# Consider sharded/offloaded loading for large teachers:
+teacher_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-70B-Instruct",
+    device_map="auto",  # or use accelerate/FSDP per your setup
+)
```
65-66: Clarify serialization behavior of an nn.Module teacher. Readers may wonder how checkpoints behave when passing a module instance. Add a short note pointing to the guide's restoration semantics.
Proposed addition:
```diff
-The `teacher_model` can be either a `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`.
+The `teacher_model` can be either a `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`.
+Note: when saving, KD-specific state (including the teacher instance) is not re-instantiated on restore; see the Distillation guide for details.
```
76-82: Fix variable name mismatch (train_loader vs train_dataloader). The code defines train_loader but iterates train_dataloader.
Apply this diff:
```diff
-for input, labels in train_dataloader:
+for input, labels in train_loader:
```
14-15: Typo: "intellegant" → "intelligent". Minor doc polish.
Apply this diff:
```diff
-| Getting Started | Learn how to optimize your models using distillation to produce more intellegant smaller models | [[Link](#getting-started)] | [[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html)] |
+| Getting Started | Learn how to optimize your models using distillation to produce more intelligent smaller models | [[Link](#getting-started)] | [[docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html)] |
```
docs/source/guides/4_distillation.rst (3)
19-22: Make restore semantics actionable. Add a pointer that resuming KD requires re-converting with kd_loss if desired.
Apply this diff:
```diff
-Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
-will not reinstantiate the distillation meta-model, in order to avoid unpickling issues.
+Note that restoring the model (via :meth:`mto.restore <modelopt.torch.opt.conversion.restore>`)
+will not reinstantiate the distillation meta-model, in order to avoid unpickling issues. To
+resume KD training, call :meth:`mtd.convert <modelopt.torch.distill.distillation.convert>`
+again on the restored student with your desired ``kd_loss`` config.
```
56-57: Clarify export comment. "Previously-present attributes" is vague.
Apply this diff:
```diff
-# Export model in original class, with only previously-present attributes
+# Export the original student class; distillation-specific attributes are removed
```
61-61: Grammar nit: "for to perform" → "to perform".
Apply this diff:
```diff
-    When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` for to perform Minifinetuning in lieu of the standard
+    When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss <modelopt.torch.distill.MFTLoss>` to perform Minifinetuning in lieu of the standard
```
examples/llm_distill/main.py (1)
126-129: Consider specifying dtype to reduce memory (esp. with bf16). Passing torch_dtype helps avoid unnecessary fp32 allocations.
Apply this diff:
```diff
-    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.teacher_name_or_path,
-        device_map=PartialState().process_index,
-    )
+    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.teacher_name_or_path,
+        device_map=PartialState().process_index,
+        torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+    )
```
modelopt/torch/distill/config.py (1)
81-89: Handle None criterion without injecting a sentinel loss. Returning {("", ""): None} creates a dict with a non-Loss value. Prefer returning {} to keep types clean; strict validation already warns on empty criterion.
Apply this diff:
```diff
     @pydantic.field_validator("criterion")
     @classmethod
     def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]:
         """Ensure criterion is a mapping from layer names to loss (potentially entire module)."""
-        if not isinstance(criterion, dict):
-            # Output-only distillation.
-            criterion = {("", ""): criterion}
-        return criterion
+        if criterion is None:
+            return {}
+        if isinstance(criterion, dict):
+            return criterion
+        # Output-only distillation.
+        return {("", ""): criterion}
```
modelopt/torch/distill/mode.py (3)
16-20: Docstring mentions NAS; update to Distillation. Minor copy/paste artifact.
Apply this diff:
```diff
-"""Module implementing and describing modes that can be used during the NAS convert process.
+"""Module implementing and describing modes that can be used during the Distillation convert process.

-Check out :meth:`mtn.convert <modelopt.torch.nas.conversion.convert>` to learn more about modes.
+Check out :meth:`mtd.convert <modelopt.torch.distill.distillation.convert>` to learn more about modes.
 """
```
177-181: Optional: emit a debug log on no-op restore. A one-line debug helps users understand why KD wasn't reconstructed.
Apply this diff:
```diff
 def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
     """Function for restoring a previously convert model to a distillation meta-model."""
-    # NOTE: DistillationModel will purposely remain unrestored
+    # NOTE: DistillationModel will purposely remain unrestored
+    warnings.warn("KD mode state was sanitized at save time; skipping DistillationModel reconstruction.", stacklevel=1)
     return model
```
183-188: Type hint and minimal doc nit for reset helper. Make the helper's intent explicit; keep behavior unchanged.
Apply this diff:
```diff
-def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
-    """Function for resetting the state's config."""
+def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> None:
+    """Reset KD-related config fields to pickle-friendly placeholders before new-mode addition or save."""
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- docs/source/guides/4_distillation.rst (4 hunks)
- examples/llm_distill/README.md (2 hunks)
- examples/llm_distill/main.py (2 hunks)
- modelopt/torch/distill/config.py (2 hunks)
- modelopt/torch/distill/mode.py (3 hunks)
- modelopt/torch/distill/plugins/huggingface.py (0 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (0 hunks)
- tests/unit/torch/distill/test_distill.py (1 hunks)
💤 Files with no reviewable changes (2)
- modelopt/torch/quantization/plugins/transformers_trainer.py
- modelopt/torch/distill/plugins/huggingface.py
🧰 Additional context used
🧬 Code graph analysis (4)
tests/unit/torch/distill/test_distill.py (2)
- modelopt/torch/opt/conversion.py (1): ModeloptStateManager (63-311)
- modelopt/torch/opt/config.py (1): values (136-138)
examples/llm_distill/main.py (1)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
modelopt/torch/distill/config.py (2)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
- modelopt/torch/opt/config.py (1): ModeloptField (50-53)
modelopt/torch/distill/mode.py (1)
- modelopt/torch/distill/config.py (1): KDLossConfig (34-120)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: wait-checks / wait
🔇 Additional comments (8)
examples/llm_distill/README.md (1)
57-60: Direct nn.Module acceptance is clear. Good change; the example demonstrates the broadened API for teacher_model correctly.
examples/llm_distill/main.py (2)
113-115: LGTM: direct single-model load aligns with README. Using AutoModelForCausalLM directly here improves clarity and removes indirection.
131-135: LGTM: pass instantiated teacher into kd_config. This matches the broadened ModelLike typing and the docs.
tests/unit/torch/distill/test_distill.py (2)
150-156: Good assertions for sanitized KD state after restore. Asserting placeholders for teacher_model, criterion, and loss_balancer precisely tests the new save/restore semantics.
157-159: LGTM: restored object remains the raw student type. This validates the intended "don't reconstruct DistillationModel" behavior.
modelopt/torch/distill/config.py (1)
43-51: Broadened ModelLike support is appropriate. Accepting nn.Module in addition to class/callable/tuple simplifies usage and is consistent with init_model_from_model_like.
modelopt/torch/distill/mode.py (1)
84-92: Pre-save hook is a solid addition. Exposing update_for_save and pointing it to the reset function centralizes KD-state sanitization before checkpointing.
docs/source/guides/4_distillation.rst (1)
205-210: External reference verified — "Minifinetuning" is correct.
Paper: "Minifinetuning: Low-Data Generation Domain Adaptation through Corrective Self-Distillation" — Peter Belcak, Greg Heinrich, Jan Kautz, Pavlo Molchanov; arXiv:2506.15702.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_distill/main.py (1)
83-91: world_size lookup can crash when DDP not initialized; avoid float division. torch.distributed.get_world_size() raises if the process group isn't initialized, and float is_integer checks are brittle.
Apply:
```diff
-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
-    if not num_accum_steps.is_integer():
-        raise ValueError(
-            f"`per_device_train_batch_size` * `world_size` must be a factor of {total_batch_size}"
-        )
-    training_args.gradient_accumulation_steps = int(num_accum_steps)
+    world_size = PartialState().num_processes
+    per_step = training_args.per_device_train_batch_size * world_size
+    if total_batch_size % per_step != 0:
+        raise ValueError(
+            f"`per_device_train_batch_size` * `world_size` ({per_step}) must divide {total_batch_size}"
+        )
+    training_args.gradient_accumulation_steps = total_batch_size // per_step
```
🧹 Nitpick comments (3)
examples/llm_distill/main.py (2)
51-51: max_length is not consumed by SFTTrainer; wire it or revert. Pass it as SFTTrainer's max_seq_length to avoid silently using the default.
Apply:
```diff
 class KDSFTTrainer(SFTTrainer, KDTrainer):
     pass
 @@
     trainer = trainer_cls(
         model,
         training_args,
         train_dataset=dset_train,
         eval_dataset=dset_eval,
         formatting_func=llama_text_format_func,
         processing_class=tokenizer,
+        max_seq_length=training_args.max_length,
     )
```
62-66: Formatting func: guard missing keys and trim whitespace. Use .get(...) to avoid KeyError and normalize system prompt.
Apply:
```diff
-    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
-    if not p:
-        return f"<s>[INST] {q}[/INST]\n{r}</s>"
-    else:
-        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+    p = (sample.get("system_prompt") or "").strip()
+    q = sample["question"]
+    r = sample["response"]
+    if p:
+        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+    return f"<s>[INST] {q}[/INST]\n{r}</s>"
```
examples/llm_distill/requirements.txt (1)
2-2: Pin HF stack for TRL 0.23.0 (avoid resolver drift). TRL's PyPI metadata shows accelerate>=1.4.0, transformers>=4.56.1, peft>=0.8.0 — add explicit pins or a constraints file so examples/llm_distill/requirements.txt (currently contains trl==0.23.0) also locks compatible versions (suggested: transformers==4.56.1, accelerate==1.4.0, peft==0.8.0). A sketch of such a pinned file follows below.
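For illustration, a pinned requirements file using the versions suggested in the comment above; the exact pins are assumptions to verify against TRL 0.23.0's published metadata before adopting them.

```text
# examples/llm_distill/requirements.txt (illustrative pins only)
trl==0.23.0
transformers==4.56.1
accelerate==1.4.0
peft==0.8.0
```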
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/llm_distill/README.md (3 hunks)
- examples/llm_distill/main.py (3 hunks)
- examples/llm_distill/requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_distill/README.md
🧰 Additional context used
🧬 Code graph analysis (1)
examples/llm_distill/main.py (1)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
examples/llm_distill/main.py (1)
128-131: Passing the instantiated teacher into kd_config looks right. This aligns with the new ModelLike teacher_model semantics and the PR goal to avoid saving problematic KD wrapper state.
Actionable comments posted: 1
♻️ Duplicate comments (4)
examples/llm_distill/README.md (1)
172-184: Same here: keep FSDP2 knobs inside the YAML. Mirror the previous change for the distillation command.
Apply:
```diff
-accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
-    --fsdp_cpu_ram_efficient_loading False \
-    --fsdp_activation_checkpointing False \
-    main.py \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
     --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
     --max_steps 200 \
     --logging_steps 5
```
examples/llm_distill/main.py (3)
109-112: Optional: enable memory-friendly loading (and consider explicit device_map). For large models, add `low_cpu_mem_usage=True`. If you hit meta-device issues in distributed setups, use an explicit device map or rely on accelerate's placement. This mirrors prior review feedback.
Apply:
```diff
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+        low_cpu_mem_usage=True,
+    )
 @@
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.student_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.student_name_or_path,
+        torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+        low_cpu_mem_usage=True,
+    )
 @@
-    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.teacher_name_or_path,
+        torch_dtype=torch.bfloat16 if training_args.bf16 else None,
+        low_cpu_mem_usage=True,
+    )
```
Also applies to: 115-117, 121-123
107-112: Bug: `from_pretrained` uses `torch_dtype`, not `dtype`. This will raise `TypeError: got an unexpected keyword argument 'dtype'`. Use `torch_dtype=...`.
Apply:
```diff
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_path, dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+    )
```
121-123: Same `torch_dtype` fix for teacher load.
Apply:
```diff
-    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+    )
```
🧹 Nitpick comments (8)
tests/unit/torch/opt/plugins/test_hf_patching.py (2)
40-44: Direct teacher instantiation LGTM; keep dtype/device consistent. Good move replacing the factory with a concrete teacher instance. Consider ensuring teacher and student use the same dtype/device in tests to avoid accidental dtype/device mismatches when kernels run on CI with different defaults. A simple `teacher_model.to(model_ref.dtype).eval()` before convert is sufficient.
56-58: Add an assertion that KD state is not re-materialized after restore. Given the PR's goal (avoid saving problematic KD state), add a check that the reloaded model is a plain base model without KD wrappers and that no KD config is present in saved state.
Apply:
```diff
 tf_output_tester(model, model_test)
 # since distill model contains loss function, we compare state of model manually
 assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
+
+# Also verify KD metadata is not persisted/reconstructed
+state = mto.modelopt_state(model_test.model)
+assert not any(k.startswith("kd_") for k in state.keys())
```
examples/llm_distill/README.md (2)
42-45: Example matches new API; add note about memory-friendly loading. Since large models are used here, add a brief note (or code) to pass `low_cpu_mem_usage=True` (and optionally a device map) to reduce CPU spikes during `from_pretrained`.
Apply:
```diff
-student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
-teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
+student_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-8B-Instruct", low_cpu_mem_usage=True
+)
+teacher_model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.1-70B-Instruct", low_cpu_mem_usage=True
+)
```
: Docs: clarify loss balancer recommendation.Mention that when only KD loss is used, omitting
loss_balancer
makes the KD loss the total loss by default; otherwise, show a weight (e.g.,StaticLossBalancer(kd_weight=1.0, student_weight=1.0)
).
157-167
: Avoid non‑portable accelerate flags; keep them in YAML.
--fsdp_cpu_ram_efficient_loading
and--fsdp_activation_checkpointing
are config keys, not stable CLI flags across accelerate versions. Recommend removing them from the command and keeping overrides inaccelerate_config/fsdp2.yaml
to prevent CI breakage.Apply:
-accelerate launch --config-file ./accelerate_config/fsdp2.yaml \ - main.py \ +accelerate launch --config-file ./accelerate_config/fsdp2.yaml \ + main.py \ --single_model \ --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \ --output_dir ./llama2-7b-sft \ --max_length 2048 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 4 \ --max_steps 400 \ --logging_steps 5tests/examples/llm_distill/test_llm_distill.py (1)
24-27: Make test independent of accelerate CLI internals. Drop accelerate-specific FSDP flags from the CLI; keep them solely in `accelerate_config/fsdp2.yaml` to reduce flakiness across versions.
Apply:
```diff
-    "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
-    "--fsdp_cpu_ram_efficient_loading", "False",
-    "--fsdp_activation_checkpointing", "False",
-    "main.py",
+    "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
+    "main.py",
```
examples/llm_distill/main.py (2)
80-90: Guard world size when not launched under accelerate. `torch.distributed.get_world_size()` errors if the process group isn't initialized. Add a safe fallback.
Apply:
```diff
-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
+    world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+    num_accum_steps = total_batch_size / (
+        training_args.per_device_train_batch_size * world_size
+    )
```
124-129: Freeze teacher to avoid accidental gradients. Explicitly set `teacher_model.requires_grad_(False)` to prevent the optimizer from touching it if a custom trainer is used.
Apply:
```diff
     kd_config = {
         "teacher_model": teacher_model,
         "criterion": LMLogitsLoss(),
     }
+    teacher_model.requires_grad_(False).eval()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/llm_distill/README.md (3 hunks)
- examples/llm_distill/accelerate_config/fsdp2.yaml (1 hunks)
- examples/llm_distill/main.py (2 hunks)
- examples/llm_distill/requirements.txt (1 hunks)
- tests/examples/llm_distill/test_llm_distill.py (1 hunks)
- tests/unit/torch/opt/plugins/test_hf_patching.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/llm_distill/accelerate_config/fsdp2.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/llm_distill/requirements.txt
🧰 Additional context used
🧬 Code graph analysis (3)
tests/examples/llm_distill/test_llm_distill.py (1)
- tests/examples/conftest.py (1): tiny_llama_path (33-41)
tests/unit/torch/opt/plugins/test_hf_patching.py (2)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
- tests/_test_utils/torch_model/transformers_models.py (1): get_tiny_qwen3 (44-60)
examples/llm_distill/main.py (1)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: wait-checks / wait
🔇 Additional comments (2)
tests/examples/llm_distill/test_llm_distill.py (1)
31-35: Flag rename looks correct. `--max_seq_length` → `--max_length` aligns with the updated `TrainingArguments`. Good.
132-134
: Confirm None is valid for generation fields across transformers versions.Setting
temperature
/top_p
toNone
is okay on recent versions; older versions expect floats. If you need broader compatibility, delete the attributes instead of assigningNone
.
Signed-off-by: Asha Anoosheh <[email protected]>
Force-pushed from 3dcd039 to 549d20f
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main     #320      +/-   ##
==========================================
- Coverage   73.82%   73.82%   -0.01%
==========================================
  Files         172      172
  Lines       17438    17437       -1
==========================================
- Hits        12874    12872       -2
- Misses       4564     4565       +1
```
☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/distill/mode.py (1)
183-188: Use a concrete Loss to avoid accidental runtime errors. `Loss()` is the abstract base; calling it would raise at runtime if ever used. Prefer a harmless concrete default.
```diff
 def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
     """Function for resetting the state's config."""
     config.teacher_model = nn.Module
-    config.criterion = Loss()
+    # Use a concrete, parameter-free loss to avoid runtime errors if ever invoked.
+    config.criterion = nn.MSELoss()
     config.loss_balancer = None
```
examples/llm_distill/main.py (1)
82-90: Single-process robustness: guard world_size when not initialized. `torch.distributed.get_world_size()` errors if the process group isn't initialized (e.g., local runs).
```diff
-    num_accum_steps = total_batch_size / (
-        training_args.per_device_train_batch_size * torch.distributed.get_world_size()
-    )
+    world_size = (
+        torch.distributed.get_world_size()
+        if torch.distributed.is_available() and torch.distributed.is_initialized()
+        else 1
+    )
+    num_accum_steps = total_batch_size / (
+        training_args.per_device_train_batch_size * world_size
+    )
```
♻️ Duplicate comments (2)
examples/llm_distill/main.py (2)
116-117: Same kwarg bug for student load: use torch_dtype, not dtype. Matches the above issue.
```diff
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.student_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+    )
```
109-111: Replace invalid from_pretrained kwarg `dtype` with `torch_dtype`. `from_pretrained` does not accept `dtype`; use `torch_dtype` to set model precision.
File: examples/llm_distill/main.py (lines 109-111)
```diff
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_path, dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+    )
```
🧹 Nitpick comments (2)
examples/llm_distill/main.py (2)
61-66: Guard dataset fields to avoid KeyError/None formatting. Open-Orca variants can miss `system_prompt`. Use `.get` and empty-string fallbacks.
```diff
-def llama_text_format_func(sample):
-    p, q, r = sample["system_prompt"], sample["question"], sample["response"]
-    if not p:
-        return f"<s>[INST] {q}[/INST]\n{r}</s>"
-    else:
-        return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
+def llama_text_format_func(sample):
+    p = (sample.get("system_prompt") or "").strip()
+    q = (sample.get("question") or "").strip()
+    r = (sample.get("response") or "").strip()
+    if not p:
+        return f"<s>[INST] {q}[/INST]\n{r}</s>"
+    return f"<s>[INST] <<SYS>>{p}<</SYS>>\n{q}[/INST]\n{r}</s>"
```
Please confirm the split's actual field names (some dumps use "system_prompt", some "system_prompt_template", etc.).
137-144: Pass dataset_num_proc to SFTTrainer for faster preprocessing. You defined it in TrainingArguments but don't pass it; wiring it through speeds tokenization.
```diff
     trainer = trainer_cls(
         model,
         training_args,
         train_dataset=dset_train,
         eval_dataset=dset_eval,
         formatting_func=llama_text_format_func,
         processing_class=tokenizer,
+        dataset_num_proc=training_args.dataset_num_proc,
     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- docs/source/guides/4_distillation.rst (4 hunks)
- examples/llm_distill/README.md (3 hunks)
- examples/llm_distill/accelerate_config/fsdp2.yaml (1 hunks)
- examples/llm_distill/main.py (2 hunks)
- examples/llm_distill/requirements.txt (1 hunks)
- modelopt/torch/distill/config.py (2 hunks)
- modelopt/torch/distill/mode.py (3 hunks)
- modelopt/torch/distill/plugins/huggingface.py (0 hunks)
- modelopt/torch/quantization/plugins/transformers_trainer.py (0 hunks)
- tests/examples/llm_distill/test_llm_distill.py (1 hunks)
- tests/unit/torch/distill/test_distill.py (1 hunks)
- tests/unit/torch/opt/plugins/test_hf_patching.py (2 hunks)
💤 Files with no reviewable changes (2)
- modelopt/torch/quantization/plugins/transformers_trainer.py
- modelopt/torch/distill/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (8)
- docs/source/guides/4_distillation.rst
- examples/llm_distill/accelerate_config/fsdp2.yaml
- modelopt/torch/distill/config.py
- tests/unit/torch/opt/plugins/test_hf_patching.py
- examples/llm_distill/requirements.txt
- tests/examples/llm_distill/test_llm_distill.py
- tests/unit/torch/distill/test_distill.py
- examples/llm_distill/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/distill/mode.py (1)
- modelopt/torch/distill/config.py (1): KDLossConfig (34-120)
examples/llm_distill/main.py (1)
- modelopt/torch/distill/distillation_model.py (1): teacher_model (124-126)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/distill/mode.py (4)
86-87: Unifying pre-new-mode reset via _reset_kd_state_config — OK. Centralizing the reset path here looks good and simplifies lifecycle handling.
179-181: Restore is intentionally a no-op — OK. Matches the documented decision to avoid re-instantiation of the distillation wrapper on restore.
Please ensure docs and error messages clearly state that KD is not reconstructed on restore and must be re-converted if needed.
220-221: Exported-student restore no-op — OK. Consistent with the KD restore behavior; no objections.
88-91: Pre-save hook mutates config in-place; avoid post-save side effects. update_for_save in KD returns _reset_kd_state_config (modelopt/torch/distill/mode.py:88-91), but conversion calls last_mode.update_for_save(model, last_config, self._last_metadata) on the stored _last_config (modelopt/torch/opt/conversion.py:310). If _reset_kd_state_config mutates teacher/criterion/loss_balancer (or other config fields) in place, a save performed mid-training will change runtime behavior after saving. Ensure update_for_save either works on a deepcopy or restores the original config (or change the caller to pass a deepcopy) so saves have no side effects; a sketch of the deepcopy approach follows below.
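For concreteness, a hypothetical sketch of the caller-side deepcopy approach suggested above. The method name `_on_save` is illustrative; only `update_for_save`, `_last_config`, and `_last_metadata` are quoted from the comment, so verify the attribute names against the real ModeloptStateManager code.

```python
import copy

# Hypothetical ModeloptStateManager helper: sanitize a copy of the stored config
# so the pre-save reset cannot leak into the live training configuration.
def _on_save(self, model):
    last_mode, last_config = self._last_mode, self._last_config
    config_for_save = copy.deepcopy(last_config)  # reset the copy, not the original
    last_mode.update_for_save(model, config_for_save, self._last_metadata)
    return config_for_save  # serialize the sanitized copy
```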
examples/llm_distill/main.py (2)
50-50: Rename to max_length — OK. Tracks recent TRL/Transformers arg conventions.
100-105: Tokenizer pad token may be None. Some LLM tokenizers lack `eos_token`/`pad_token`. Consider adding a pad token if missing to avoid collation errors.
```diff
-    tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    # If still None, add a pad token and resize later if needed.
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "<pad>"})
```
```python
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
```
Same kwarg bug for teacher load: use torch_dtype
Keep teacher/student loads consistent.
```diff
-    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
-    )
+    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
+    )
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.teacher_name_or_path, torch_dtype=torch.bfloat16 if training_args.bf16 else None
)
```
🤖 Prompt for AI Agents
In examples/llm_distill/main.py around lines 121 to 123, the teacher model is
loaded using the incorrect keyword argument dtype= when calling
transformers.AutoModelForCausalLM.from_pretrained; change that to torch_dtype=
and pass torch.bfloat16 if training_args.bf16 else None so the teacher load
matches the student load and uses the correct HF transformers parameter.
What does this PR do?
Feature: Change the way "kd_loss" mode saves state
Overview: ?
Usage
# Add a code snippet demonstrating how to use this
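A hedged usage sketch for this section, based on the updated examples discussed in the review above (not authored by the PR); it shows passing an instantiated `nn.Module` teacher directly and notes the new save/restore behavior.

```python
import modelopt.torch.distill as mtd
from transformers import AutoModelForCausalLM

# Teacher may now be a plain, already-instantiated nn.Module.
student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")

kd_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
# On save, KD-specific state is reset to placeholders; on restore, the raw
# student is returned and must be re-converted to resume distillation.
```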
Testing
Unit
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Documentation
Tests